package com.dec.kks.etl.model;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.RFormula;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import java.io.IOException;

/**
 * Trains a decision-tree regression pipeline on a CSV file, persists the fitted
 * pipeline, reloads it, and reports the RMSE of the reloaded model.
 *
 * <p>The pipeline consists of an {@link RFormula} stage using the formula
 * {@code rise ~ ecurrent + flow}, followed by a {@link DecisionTreeRegressor}.
 *
 * <p>Note that the model is evaluated on the same dataset it was fitted on, so
 * the reported RMSE is a training error, not a held-out test error.
 */
public class DCTModelTrainMain {

    /** CSV input used when no path is supplied on the command line. */
    private static final String DEFAULT_CSV_PATH =
            "/home/hdfs/soft/dec-project/dec-kks-etl/data/coil.csv";

    /** Model output location used when no path is supplied on the command line. */
    private static final String DEFAULT_MODEL_PATH = "model/regress/dct";

    /**
     * Program entry point.
     *
     * @param args optional overrides: {@code args[0]} is the CSV input path and
     *             {@code args[1]} is the model output path; either may be omitted,
     *             in which case the built-in default is used
     * @throws IOException if saving the fitted pipeline fails
     */
    public static void main(String[] args) throws IOException {
        System.setProperty("hadoop.home.dir","/home/hdfs/bigdata/hadoop-2.7.4");

        // Optional command-line overrides; with no arguments the original
        // hard-coded locations are used.
        String pathcsv = (args != null && args.length > 0) ? args[0] : DEFAULT_CSV_PATH;
        String modelPath = (args != null && args.length > 1) ? args[1] : DEFAULT_MODEL_PATH;

        SparkSession spark = SparkSession
                .builder()
                .appName("coid model train")
                .master("local")
                .getOrCreate();

        // Everything after session creation runs inside try/finally so the
        // session is stopped even when loading, fitting or saving throws.
        try {
            spark.sparkContext().setLogLevel("WARN");

            // Read the CSV with a header row and let Spark infer column types.
            Dataset<Row> data = spark.read().format("csv")
                    .option("sep", ",")
                    .option("inferSchema", "true")
                    .option("header", "true")
                    .load(pathcsv);

            // Stage 0: derive the "features" vector and "label" column from the formula.
            RFormula formula = new RFormula()
                    .setFormula("rise ~ ecurrent + flow")
                    .setFeaturesCol("features")
                    .setLabelCol("label");

            // Stage 1: the regressor consuming the columns produced by stage 0.
            DecisionTreeRegressor dt = new DecisionTreeRegressor()
                    .setFeaturesCol("features")
                    .setLabelCol("label")
                    .setImpurity("variance")
                    .setMaxBins(15)
                    .setMaxDepth(4)
                    .setMinInstancesPerNode(2);

            Pipeline pipeline = new Pipeline()
                    .setStages(new PipelineStage[]{formula, dt});

            PipelineModel model = pipeline.fit(data);

            // Persist, then reload, so that the evaluation below exercises the
            // saved artifact rather than the in-memory model.
            model.write().overwrite().save(modelPath);
            model = PipelineModel.load(modelPath);

            Dataset<Row> predictions = model.transform(data);

            predictions.select("label", "features").show(5);

            RegressionEvaluator evaluator = new RegressionEvaluator()
                    .setLabelCol("label")
                    .setPredictionCol("prediction")
                    .setMetricName("rmse");
            double rmse = evaluator.evaluate(predictions);
            // The predictions come from the dataset used for fitting, so this
            // is a training-set error.
            System.out.println("Root Mean Squared Error (RMSE) on training data = " + rmse);

            // Index 1 is the regressor stage, matching the order given to setStages above.
            DecisionTreeRegressionModel treeModel =
                    (DecisionTreeRegressionModel) (model.stages()[1]);
            System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
        } finally {
            spark.stop();
        }
    }

}
